"""
Phase 4: Flow → Outcome IV Regressions
========================================
Use gravity-predicted demographic inflows as the instrument for actual portfolio
inflows. Manual two-stage least squares (2SLS) via PanelGLS, with standard errors
from a country-cluster bootstrap.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Paths: this script lives in a subdirectory of the project directory; the
# shared PanelGLS estimator is imported from the sibling "multilateral"
# project, so its src/ directory must be on sys.path before the import.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input panel directory and output table directory (the latter is created at
# import time if it does not exist).
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Number of cluster-bootstrap replications used for the IV standard errors.
N_BOOTSTRAP = 100
# Seed NumPy's global RNG, which the bootstrap resampling (np.random.choice)
# draws from, so results are reproducible across runs.
np.random.seed(42)


def stars(p):
    """Return the conventional significance marker for p-value ``p``.

    ``'***'`` if p < 0.01, ``'**'`` if p < 0.05, ``'*'`` if p < 0.1,
    otherwise the empty string (including for NaN).
    """
    thresholds = ((0.01, '***'), (0.05, '**'), (0.1, '*'))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ''


def run_ols(df, y_var, x_vars, label, min_obs=50):
    """Fit a PanelGLS regression of ``y_var`` on ``x_vars`` and summarize it.

    Rows with a missing value in the outcome, any used regressor, ``iso3`` or
    ``year`` are dropped before fitting. Regressors absent from ``df`` are
    omitted (and reported) rather than raising.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``y_var``, the regressors, ``iso3`` and ``year``.
    y_var : str
        Outcome column.
    x_vars : list of str
        Regressor columns; coefficients are reported in this order.
    label : str
        Model name, printed and stored under ``'model'``.
    min_obs : int, default 50
        Skip the regression if fewer complete observations remain.

    Returns
    -------
    dict or None
        ``model``, ``y_var``, ``n_obs``, ``r_squared``, ``rho`` plus
        ``<var>_coef``, ``<var>_se`` and ``<var>_p`` for each regressor used,
        or ``None`` if the regression was skipped.
    """
    # Without these columns the fit cannot run; skip instead of a KeyError.
    missing = [c for c in (y_var, 'iso3', 'year') if c not in df.columns]
    if missing:
        print(f"  {label}: missing required columns {missing}, skipping")
        return None

    actual_x = [v for v in x_vars if v in df.columns]
    omitted = [v for v in x_vars if v not in df.columns]
    if omitted:
        # Dropping a regressor changes the specification; make it visible.
        print(f"  {label}: regressors not in data, omitted: {omitted}")

    sub = df[[y_var] + actual_x + ['iso3', 'year']].dropna()
    if len(sub) < min_obs or not actual_x:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[actual_x].values,
            sub['iso3'].values, sub['year'].values)

    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    row = {'model': label, 'y_var': y_var, 'n_obs': gls.n_obs,
           'r_squared': gls.r_squared, 'rho': gls.rho}
    # beta/se/pvalues are indexed in the same order as the regressor columns.
    for i, name in enumerate(actual_x):
        s = stars(gls.pvalues[i])
        print(f"    {name:35s} {gls.beta[i]:10.5f} ({gls.se[i]:.5f}) {s}")
        row[f'{name}_coef'] = gls.beta[i]
        row[f'{name}_se'] = gls.se[i]
        row[f'{name}_p'] = gls.pvalues[i]
    return row


def _fit_two_stages(data, y_var, endog_var, s1_x, s2_x):
    """Fit both 2SLS stages with PanelGLS on ``data``; return (stage1, stage2).

    Stage 1 regresses ``endog_var`` on ``s1_x``. Its fitted values are written
    into ``data['fitted_endog']`` (mutating ``data``), and stage 2 regresses
    ``y_var`` on ``s2_x``, which is expected to contain ``'fitted_endog'``.
    """
    gls1 = PanelGLS()
    gls1.fit(data[endog_var].values, data[s1_x].values,
             data['iso3'].values, data['year'].values)
    # NOTE(review): assumes PanelGLS.fitted is row-aligned with its input,
    # as the original implementation did — confirm against PanelGLS.
    data['fitted_endog'] = gls1.fitted

    gls2 = PanelGLS()
    gls2.fit(data[y_var].values, data[s2_x].values,
             data['iso3'].values, data['year'].values)
    return gls1, gls2


def _cluster_bootstrap_iv(sub, y_var, endog_var, s1_x, s2_x, n_boot):
    """Re-estimate both stages on ``n_boot`` country-resampled panels.

    Countries are drawn with replacement from NumPy's global RNG. Each draw
    gets a unique entity id so duplicated countries act as distinct panels.
    Returns ``(coefs, n_failed)``: the finite stage-2 coefficients on
    ``fitted_endog`` and the number of draws that errored or were non-finite.
    """
    countries = sub['iso3'].unique()
    # Pre-split once instead of boolean-masking the whole frame per draw.
    by_country = {c: g for c, g in sub.groupby('iso3', sort=False)}

    coefs = []
    n_failed = 0
    for _ in range(n_boot):
        draw = np.random.choice(countries, size=len(countries), replace=True)
        boot_df = pd.concat(
            [by_country[c].assign(iso3=f"{c}_{i}") for i, c in enumerate(draw)],
            ignore_index=True,
        )
        try:
            _, g2 = _fit_two_stages(boot_df, y_var, endog_var, s1_x, s2_x)
            coef = float(g2.beta[0])
        except Exception:
            # Best effort: a resample can be degenerate (e.g. singular design).
            n_failed += 1
            continue
        if np.isfinite(coef):
            coefs.append(coef)
        else:
            n_failed += 1
    return coefs, n_failed


def run_2sls(df, y_var, endog_var, instrument_var, controls, label):
    """
    Manual 2SLS via PanelGLS with country-cluster bootstrap standard errors.

    Stage 1: endog = α + β·instrument + controls + u
    Stage 2: outcome = α + δ·fitted_endog + controls + ε

    The stage-2 SE (reported as "naive") ignores that fitted_endog is itself
    estimated. Inference instead uses the standard deviation of δ over
    N_BOOTSTRAP re-estimations of both stages on country-resampled panels.

    Returns a dict of first-stage diagnostics and IV estimates, or None when a
    required column is missing or fewer than 50 complete observations remain.
    """
    required = [y_var, endog_var, instrument_var, 'iso3', 'year']
    missing = [c for c in required if c not in df.columns]
    if missing:
        print(f"  {label}: missing required columns {missing}, skipping")
        return None

    actual_controls = [c for c in controls if c in df.columns]
    omitted = [c for c in controls if c not in df.columns]
    if omitted:
        print(f"  {label}: controls not in data, omitted: {omitted}")

    cols = [y_var, endog_var, instrument_var] + actual_controls + ['iso3', 'year']
    sub = df[cols].dropna().copy()
    if len(sub) < 50:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    n_clusters = sub['iso3'].nunique()
    print(f"\n  {label}")
    print(f"  Sample: {len(sub)} obs, {n_clusters} countries")

    s1_x = [instrument_var] + actual_controls
    s2_x = ['fitted_endog'] + actual_controls
    gls1, gls2 = _fit_two_stages(sub, y_var, endog_var, s1_x, s2_x)

    # With a single excluded instrument, the first-stage F equals t².
    first_stage_f = (gls1.beta[0] / gls1.se[0]) ** 2
    first_stage_p = gls1.pvalues[0]
    print(f"  Stage 1: {endog_var} ~ {instrument_var}")
    print(f"    Instrument coef: {gls1.beta[0]:.5f} (SE={gls1.se[0]:.5f}, p={first_stage_p:.4f})")
    print(f"    First-stage F: {first_stage_f:.2f}, R²: {gls1.r_squared:.4f}")

    iv_coef = gls2.beta[0]
    iv_se_naive = gls2.se[0]
    print(f"  Stage 2: {y_var} ~ fitted({endog_var})")
    print(f"    IV coef (naive SE): {iv_coef:.5f} ({iv_se_naive:.5f})")

    print(f"  Bootstrapping ({N_BOOTSTRAP} iterations)...")
    boot_coefs, n_failed = _cluster_bootstrap_iv(
        sub, y_var, endog_var, s1_x, s2_x, N_BOOTSTRAP)
    if n_failed:
        print(f"    {n_failed}/{N_BOOTSTRAP} bootstrap draws failed and were discarded")

    if len(boot_coefs) > 10:
        from scipy import stats as scipy_stats
        boot_se = np.std(boot_coefs, ddof=1)
        boot_t = iv_coef / boot_se if boot_se > 0 else 0
        # Clustered inference: compare against t with G-1 df (G = countries);
        # N-k df would overstate the effective sample size.
        cluster_df = max(n_clusters - 1, 1)
        boot_p = 2 * scipy_stats.t.sf(abs(boot_t), df=cluster_df)
        print(f"    IV coef (bootstrap SE): {iv_coef:.5f} ({boot_se:.5f}) p={boot_p:.4f}")
    else:
        boot_se = iv_se_naive
        boot_p = gls2.pvalues[0]
        print("    Bootstrap failed, using naive SE")

    row = {
        'model': label, 'y_var': y_var, 'n_obs': gls2.n_obs,
        'first_stage_F': first_stage_f,
        'first_stage_p': first_stage_p,
        # Total R² of the first stage (not the partial R² of the instrument).
        'first_stage_R2': gls1.r_squared,
        'iv_coef': iv_coef,
        'iv_se_naive': iv_se_naive,
        'iv_se_bootstrap': boot_se,
        'iv_p_bootstrap': boot_p,
        'r_squared': gls2.r_squared,
    }
    return row


def _write_markdown(results, path, flow_vars):
    """Write the IV and OLS/reduced-form summary tables to ``path`` as Markdown.

    ``results`` are the row dicts produced by ``run_ols``/``run_2sls``;
    ``flow_vars`` names the regressors whose coefficient is shown in the
    OLS/reduced-form table (the first of them present in each row).
    """
    out = ["# Table 4: Flow → Outcome Regressions\n\n"]

    # IV results summary
    iv_results = [r for r in results if r.get('model', '').startswith('IV:')]
    if iv_results:
        out.append("## IV Results (Gravity-Predicted Demographic Inflows as Instrument)\n\n")
        out.append("| Outcome | N | 1st-F | IV β | Bootstrap SE | p |\n")
        out.append("|---------|---|-------|------|-------------|---|\n")
        for r in iv_results:
            out.append(f"| {r['model']} | {r['n_obs']} | {r['first_stage_F']:.1f} "
                       f"| {r['iv_coef']:.4f} | {r['iv_se_bootstrap']:.4f} "
                       f"| {r['iv_p_bootstrap']:.4f} |\n")
        out.append("\n")

    # OLS and reduced form
    ols_results = [r for r in results
                   if 'OLS' in r.get('model', '') or 'Reduced' in r.get('model', '')]
    if ols_results:
        out.append("## OLS and Reduced Form\n\n")
        out.append("| Model | N | R² | Flow β | SE | p |\n")
        out.append("|-------|---|-----|--------|-----|---|\n")
        for r in ols_results:
            # Pick the flow regressor by name rather than by the "log_" prefix
            # plus dict order (the control log_rel_opw shares that prefix).
            flow_var = next((v for v in flow_vars if f'{v}_coef' in r), None)
            if flow_var is None:
                continue
            out.append(f"| {r['model']} | {r['n_obs']} | {r['r_squared']:.4f} "
                       f"| {r[f'{flow_var}_coef']:.4f} | {r[f'{flow_var}_se']:.4f} "
                       f"| {r[f'{flow_var}_p']:.4f} |\n")
        out.append("\n")

    out.append(f"*Bootstrap SEs from {N_BOOTSTRAP} cluster-bootstrap iterations "
               f"(resampling countries).*\n")

    # Explicit UTF-8: the table contains →, ² and β, which platform-default
    # encodings such as cp1252 cannot encode.
    path.write_text("".join(out), encoding='utf-8')


def main():
    """Run the Phase 4 flow → outcome regressions and write the result tables.

    Reads ``deepening_panel.csv`` from DATA. For each available outcome it
    fits OLS on actual portfolio inflows, the reduced form on predicted
    demographic inflows, and the 2SLS/IV model, then an OLS comparison of
    portfolio vs. FDI inflows. Writes ``flow_outcomes_results.csv`` and
    ``flow_outcomes.md`` to OUT_TABLES.
    """
    print("=" * 70)
    print("PHASE 4: FLOW → OUTCOME IV REGRESSIONS")
    print("=" * 70)

    portfolio_var = 'log_total_portfolio_inflows'
    instrument_var = 'log_predicted_demo_inflows'
    fdi_var = 'log_total_fdi_inflows'

    df = pd.read_csv(DATA / "deepening_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    # Check instrument availability
    has_instruments = instrument_var in df.columns
    has_portfolio = portfolio_var in df.columns
    n_iv = df[instrument_var].notna().sum() if has_instruments else 0
    print(f"Instrument coverage: {n_iv} obs with predicted_demo_inflows")

    if not has_instruments or n_iv < 100:
        print("WARNING: Insufficient instrument coverage. Proceeding with available data.")

    results = []
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw', 'kaopen']

    outcomes = {
        'delta_log_kl': 'Capital Deepening',
        'gross_fixed_investment_gdp': 'Investment/GDP',
        'delta_log_tfp': 'TFP Growth',
        'rgdp_growth': 'GDP Growth',
        'mpk_proxy': 'MPK Proxy',
    }

    for y_var, y_label in outcomes.items():
        if y_var not in df.columns:
            print(f"\nSkipping {y_label}: variable missing")
            continue

        print(f"\n{'='*50}")
        print(f"Outcome: {y_label}")
        print(f"{'='*50}")

        # ── OLS: actual inflows → outcome ──
        if has_portfolio:
            r = run_ols(df, y_var, [portfolio_var] + controls,
                        f'OLS: Portfolio → {y_label}')
            if r is not None:
                results.append(r)

        # ── Reduced form: predicted demo inflows → outcome ──
        if has_instruments:
            r = run_ols(df, y_var, [instrument_var] + controls,
                        f'Reduced Form: Demo Inflows → {y_label}')
            if r is not None:
                results.append(r)

        # ── IV: instrumented inflows → outcome ──
        if has_instruments and has_portfolio:
            r = run_2sls(df, y_var,
                         endog_var=portfolio_var,
                         instrument_var=instrument_var,
                         controls=controls,
                         label=f'IV: Portfolio → {y_label}')
            if r is not None:
                results.append(r)

    # ── Portfolio vs. FDI decomposition (OLS only; no FDI instrument) ──
    print(f"\n{'='*50}")
    print("Portfolio vs. FDI: OLS Comparison")
    print(f"{'='*50}")

    for flow_type in (portfolio_var, fdi_var):
        if flow_type not in df.columns:
            continue
        flow_label = 'Portfolio' if flow_type == portfolio_var else 'FDI'
        for y_var in ('delta_log_kl', 'gross_fixed_investment_gdp'):
            if y_var not in df.columns:
                continue
            y_label = 'K/L' if 'kl' in y_var else 'Invest'
            r = run_ols(df, y_var, [flow_type] + controls,
                        f'OLS: {flow_label} → {y_label}')
            if r is not None:
                results.append(r)

    # ── Save ──
    results_df = pd.DataFrame(results)
    results_df.to_csv(OUT_TABLES / "flow_outcomes_results.csv", index=False)

    _write_markdown(results, OUT_TABLES / "flow_outcomes.md",
                    flow_vars=(portfolio_var, instrument_var, fdi_var))

    print("\nPhase 4 complete.")


# Run the full Phase 4 pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
